package cn.pomelo.ws.endpoint;

import cn.pomelo.ws.listener.OpenAiWebSocketEventSourceListener;
import com.unfbx.chatgpt.OpenAiStreamClient;
import com.unfbx.chatgpt.entity.chat.BaseMessage;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
import com.unfbx.chatgpt.interceptor.OpenAILogger;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.logging.HttpLoggingInterceptor;
import org.springframework.stereotype.Component;

import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;

@Slf4j
@Component
@ServerEndpoint("/chat/{uid}")
public class WebSocketServer {


    // 每一个连接都会有一个 session
    private Session session;
    // 当前连接的用户的id
    private String uid;

    private final static Map<String, WebSocketServer> SERVERS = new ConcurrentHashMap<>();

    /**
     * 客户端和服务端建立连接
     *
     * @param session
     * @param uid
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("uid") String uid) {
        this.session = session;
        this.uid = uid;
        if (SERVERS.containsKey(uid)) {
            SERVERS.replace(uid, this);
        } else {
            SERVERS.put(uid, this);
        }
        log.info("用户 {} 连接成功", uid);
    }

    /**
     * 连接断开
     *
     * @param session
     * @param error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.info("连接id：{}，错误原因：{}", this.uid, error.getMessage());
        error.printStackTrace();
    }

    @OnClose
    public void onClose() {
        SERVERS.remove(this.uid);
        log.info("[连接ID:{}] 断开连接", uid);
    }


    /**
     * 客户端 给 服务端发送消息
     * @param msg
     */
    @OnMessage
    public void onMessage(String msg) {
        log.info("连接id：{}，收到消息：{}", this.uid, msg);
        ask(msg);
    }

    private void ask(String question) {
        // 将当前对象作为参数传递给 监听器，在监听器中 使用session对象发送消息
        // 监听器监听 openai 的响应
        OpenAiWebSocketEventSourceListener listener = new OpenAiWebSocketEventSourceListener(this.session);
        // 构建 OpenAI 连接
        List<Message> messages = new ArrayList<>();
        HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor(new OpenAILogger());
        httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
        OkHttpClient okHttpClient = new OkHttpClient
                .Builder()
                .addInterceptor(httpLoggingInterceptor)//自定义日志
                .connectTimeout(30, TimeUnit.SECONDS)//自定义超时时间
                .writeTimeout(30, TimeUnit.SECONDS)//自定义超时时间
                .readTimeout(30, TimeUnit.SECONDS)//自定义超时时间
                .build();
        OpenAiStreamClient client = OpenAiStreamClient.builder()
                .apiKey(List.of("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"))
                .okHttpClient(okHttpClient)
               .apiHost("https://ycdl.ysywy.shop/")
                .build();
        Message message = Message.builder().role(BaseMessage.Role.USER).content(question).build();
        messages.add(message);
        //聊天模型：gpt-3.5
        ChatCompletion chatCompletion = ChatCompletion.builder().messages(messages).build();
        client.streamChatCompletion(chatCompletion, listener);
    }
}
